from textwrap import dedent
from operator import eq, ge, lt
import re
import os

import pytest

from jedi.inference.gradual.conversion import _stub_to_python_value_set
from ..helpers import get_example_dir


@pytest.mark.parametrize(
    'code, sig, names, op, version', [
        ('import math; math.cos', 'cos(x, /)', ['x'], ge, (3, 6)),

        ('next', 'next(iterator, default=None, /)', ['iterator', 'default'], lt, (3, 12)),
        ('next', 'next()', [], eq, (3, 12)),
        ('next', 'next(iterator, default=None, /)', ['iterator', 'default'], ge, (3, 13)),

        ('str', "str(object='', /) -> str", ['object'], ge, (3, 6)),

        ('pow', 'pow(base, exp, mod=None)', ['base', 'exp', 'mod'], ge, (3, 8)),

        ('bytes.partition', 'partition(self, sep, /)', ['self', 'sep'], ge, (3, 6)),

        ('bytes().partition', 'partition(sep, /)', ['sep'], ge, (3, 6)),
    ]
)
def test_compiled_signature(Script, environment, code, sig, names, op, version):
    """Check signature strings and parameter names of compiled callables.

    Each case only applies when ``op(<environment version>, version)`` holds;
    otherwise the test returns early and a sibling case with a complementary
    version condition is expected to cover that interpreter.

    :param code: Source whose inferred value is the callable under test.
    :param sig: Expected result of ``signature.to_string()``.
    :param names: Expected parameter names, in order.
    :param op: Comparison operator from :mod:`operator` (``eq``/``ge``/``lt``).
    :param version: ``(major, minor)`` tuple the environment is compared to.
    """
    # Compare only as many components as the case specifies.  The environment
    # version also carries a micro component, so comparing it whole against
    # ``(3, 12)`` with ``eq`` could never be true and the case would silently
    # never run.
    env_version = environment.version_info[:len(version)]
    if not op(env_version, version):
        return  # The test right next to it should take over.

    # Every step is expected to yield exactly one result; the single-element
    # unpacking fails loudly otherwise.
    d, = Script(code).infer()
    value, = d._name.infer()
    compiled, = _stub_to_python_value_set(value)
    signature, = compiled.get_signatures()
    assert signature.to_string() == sig
    assert [n.string_name for n in signature.get_param_names()] == names


# Source prefix for test_tree_signature: a class with one classmethod and one
# staticmethod.  The test cases append a call such as ``X.x(`` or
# ``X().static(`` to it.
classmethod_code = '''
class X:
    @classmethod
    def x(cls, a, b):
        pass

    @staticmethod
    def static(a, b):
        pass
'''


# Source prefix for test_tree_signature: ``functools.partial`` objects with
# zero, one positional, and positional + keyword arguments bound.  ``d`` is a
# partial() created without any function; the test expects no signature there.
partial_code = '''
import functools

def func(a, b, c):
    pass

a = functools.partial(func)
b = functools.partial(func, 1)
c = functools.partial(func, 1, c=2)
d = functools.partial()
'''


# Source prefix for test_tree_signature: same bindings as ``partial_code`` but
# using ``functools.partialmethod`` inside a class, so the cases can call the
# attributes both on the instance (``X().a(``) and on the class (``X.c(``).
partialmethod_code = '''
import functools

class X:
    def func(self, a, b, c):
        pass
    a = functools.partialmethod(func)
    b = functools.partialmethod(func, 1)
    c = functools.partialmethod(func, 1, c=2)
    d = functools.partialmethod()
'''


@pytest.mark.parametrize(
    'code, expected', [
        ('def f(a, * args, x): pass\n f(', 'f(a, *args, x)'),
        ('def f(a, *, x): pass\n f(', 'f(a, *, x)'),
        ('def f(*, x= 3,**kwargs): pass\n f(', 'f(*, x=3, **kwargs)'),
        ('def f(x,/,y,* ,z): pass\n f(', 'f(x, /, y, *, z)'),
        ('def f(a, /, *, x=3, **kwargs): pass\n f(', 'f(a, /, *, x=3, **kwargs)'),

        (classmethod_code + 'X.x(', 'x(a, b)'),
        (classmethod_code + 'X().x(', 'x(a, b)'),
        (classmethod_code + 'X.static(', 'static(a, b)'),
        (classmethod_code + 'X().static(', 'static(a, b)'),

        (partial_code + 'a(', 'func(a, b, c)'),
        (partial_code + 'b(', 'func(b, c)'),
        (partial_code + 'c(', 'func(b)'),
        (partial_code + 'd(', None),

        (partialmethod_code + 'X().a(', 'func(a, b, c)'),
        (partialmethod_code + 'X().b(', 'func(b, c)'),
        (partialmethod_code + 'X().c(', 'func(b)'),
        (partialmethod_code + 'X().d(', None),
        (partialmethod_code + 'X.c(', 'func(a, b)'),
        (partialmethod_code + 'X.d(', None),

        ('import contextlib\n@contextlib.contextmanager\ndef f(x): pass\nf(', 'f(x)'),

        # typing lib
        ('from typing import cast\ncast(', {
            'cast(typ: object, val: Any) -> Any',
            'cast(typ: str, val: Any) -> Any',
            'cast(typ: Type[_T], val: Any) -> _T'}),
        ('from typing import TypeVar\nTypeVar(',
         'TypeVar(name: str, *constraints: Type[Any], bound: Union[None, Type[Any], str]=..., '
         'covariant: bool=..., contravariant: bool=...)'),
        ('from typing import List\nList(', None),
        ('from typing import List\nList[int](', None),
        ('from typing import Tuple\nTuple(', None),
        ('from typing import Tuple\nTuple[int](', None),
        ('from typing import Optional\nOptional(', None),
        ('from typing import Optional\nOptional[int](', None),
        ('from typing import Any\nAny(', None),
        ('from typing import NewType\nNewType(', 'NewType(name: str, tp: Type[_T]) -> Type[_T]'),
    ]
)
def test_tree_signature(Script, environment, code, expected):
    """Compare the call signatures found at the end of ``code`` with ``expected``.

    ``expected`` is ``None`` when no signature may be reported, a single
    string for exactly one signature, or a set of strings when several
    signatures (e.g. overloads) are expected.
    """
    # Only test this in the latest version, because of /
    if environment.version_info < (3, 8):
        pytest.skip()

    signatures = Script(code).get_signatures()
    if expected is None:
        assert not signatures
        return

    # Normalise the single-string form so both shapes compare as sets.
    wanted = expected if isinstance(expected, set) else {expected}
    actual = {signature.to_string() for signature in signatures}
    assert wanted == actual


@pytest.mark.parametrize(
    'combination, expected', [
        # Functions
        ('full_redirect(simple)', 'b, *, c'),
        ('full_redirect(simple4)', 'b, x: int'),
        ('full_redirect(a)', 'b, *args'),
        ('full_redirect(kw)', 'b, *, c, **kwargs'),
        ('full_redirect(akw)', 'c, *args, **kwargs'),

        # Non functions
        ('full_redirect(lambda x, y: ...)', 'y'),
        ('full_redirect()', '*args, **kwargs'),
        ('full_redirect(1)', '*args, **kwargs'),

        # Classes / inheritance
        ('full_redirect(C)', 'z, *, c'),
        ('full_redirect(C())', 'y'),
        ('full_redirect(G)', 't: T'),
        ('full_redirect(G[str])', '*args, **kwargs'),
        ('D', 'D(a, z, /)'),
        ('D()', 'D(x, y)'),
        ('D().foo', 'foo(a, *, bar, z, **kwargs)'),

        # Merging
        ('two_redirects(simple, simple)', 'a, b, *, c'),
        ('two_redirects(simple2, simple2)', 'x'),
        ('two_redirects(akw, kw)', 'a, c, *args, **kwargs'),
        ('two_redirects(kw, akw)', 'a, b, *args, c, **kwargs'),

        ('two_kwargs_redirects(simple, simple)', '*args, a, b, c'),
        ('two_kwargs_redirects(kw, kw)', '*args, a, b, c, **kwargs'),
        ('two_kwargs_redirects(simple, kw)', '*args, a, b, c, **kwargs'),
        ('two_kwargs_redirects(simple2, two_kwargs_redirects(simple, simple))',
         '*args, x, a, b, c'),

        ('combined_redirect(simple, simple2)', 'a, b, /, *, x'),
        ('combined_redirect(simple, simple3)', 'a, b, /, *, a, x: int'),
        ('combined_redirect(simple2, simple)', 'x, /, *, a, b, c'),
        ('combined_redirect(simple3, simple)', 'a, x: int, /, *, a, b, c'),

        ('combined_redirect(simple, kw)', 'a, b, /, *, a, b, c, **kwargs'),
        ('combined_redirect(kw, simple)', 'a, b, /, *, a, b, c'),

        ('combined_lot_of_args(kw, simple4)', '*, b'),
        ('combined_lot_of_args(simple4, kw)', '*, b, c, **kwargs'),

        ('combined_redirect(combined_redirect(simple2, simple4), combined_redirect(kw, simple5))',
         'x, /, *, y'),
        ('combined_redirect(combined_redirect(simple4, simple2), combined_redirect(simple5, kw))',
         'a, b, x: int, /, *, a, b, c, **kwargs'),
        ('combined_redirect(combined_redirect(a, kw), combined_redirect(kw, simple5))',
         'a, b, /, *args, y'),

        ('no_redirect(kw)', '*args, **kwargs'),
        ('no_redirect(akw)', '*args, **kwargs'),
        ('no_redirect(simple)', '*args, **kwargs'),
    ]
)
def test_nested_signatures(Script, environment, combination, expected):
    """Check signatures of callables that forward ``*args``/``**kwargs``.

    ``combination`` is an expression built from the helper functions and
    classes defined in the snippet below; it is assigned to ``z`` and the
    signature reported for ``z(`` is compared with ``expected``.

    ``expected`` is either a complete signature string (``name(...)``) or just
    a parameter list, in which case it is wrapped as ``<lambda>(...)`` because
    the redirect helpers return lambdas.
    """
    code = dedent('''
        def simple(a, b, *, c): ...
        def simple2(x): ...
        def simple3(a, x: int): ...
        def simple4(a, b, x: int): ...
        def simple5(y): ...
        def a(a, b, *args): ...
        def kw(a, b, *, c, **kwargs): ...
        def akw(a, c, *args, **kwargs): ...

        def no_redirect(func):
            return lambda *args, **kwargs: func(1)
        def full_redirect(func):
            return lambda *args, **kwargs: func(1, *args, **kwargs)
        def two_redirects(func1, func2):
            return lambda *args, **kwargs: func1(*args, **kwargs) + func2(1, *args, **kwargs)
        def two_kwargs_redirects(func1, func2):
            return lambda *args, **kwargs: func1(**kwargs) + func2(1, **kwargs)
        def combined_redirect(func1, func2):
            return lambda *args, **kwargs: func1(*args) + func2(**kwargs)
        def combined_lot_of_args(func1, func2):
            return lambda *args, **kwargs: func1(1, 2, 3, 4, *args) + func2(a=3, x=1, y=1, **kwargs)

        class C:
            def __init__(self, a, z, *, c): ...
            def __call__(self, x, y): ...

            def foo(self, bar, z, **kwargs): ...

        class D(C):
            def __init__(self, *args):
                super().__init__(*args)

            def foo(self, a, **kwargs):
                super().foo(**kwargs)

        from typing import Generic, TypeVar
        T = TypeVar('T')
        class G(Generic[T]):
            def __init__(self, i, t: T): ...
    ''')
    code += 'z = ' + combination + '\nz('
    # Exactly one signature is expected at the open call.
    sig, = Script(code).get_signatures()
    computed = sig.to_string()
    # Bare parameter lists (no leading ``name(``) describe the returned lambda.
    if not re.match(r'\w+\(', expected):
        expected = '<lambda>(' + expected + ')'
    assert expected == computed


def test_pow_signature(Script, environment):
    """Check the full set of signatures reported for the builtin ``pow``.

    The expected strings differ by environment version: below 3.8 every
    signature ends its parameter list with ``/``, from 3.8 on none does.
    """
    # See github #1357
    actual = {signature.to_string() for signature in Script('pow(').get_signatures()}
    if environment.version_info < (3, 8):
        expected = {
            'pow(base: _SupportsPow2[_E, _T_co], exp: _E, /) -> _T_co',
            'pow(base: _SupportsPow3[_E, _M, _T_co], exp: _E, mod: _M, /) -> _T_co',
            'pow(base: float, exp: float, mod: None=..., /) -> float',
            'pow(base: int, exp: int, mod: None=..., /) -> Any',
            'pow(base: int, exp: int, mod: int, /) -> int',
        }
    else:
        expected = {
            'pow(base: _SupportsPow2[_E, _T_co], exp: _E) -> _T_co',
            'pow(base: _SupportsPow3[_E, _M, _T_co], exp: _E, mod: _M) -> _T_co',
            'pow(base: float, exp: float, mod: None=...) -> float',
            'pow(base: int, exp: int, mod: None=...) -> Any',
            'pow(base: int, exp: int, mod: int) -> int',
        }
    assert actual == expected


@pytest.mark.parametrize(
    'code, signature', [
        [dedent('''
            # identifier:A
            import functools
            def f(x):
                pass
            def x(f):
                @functools.wraps(f)
                def wrapper(*args):
                    return f(*args)
                return wrapper

            x(f)('''), 'f(x, /)'],
        [dedent('''
            # identifier:B
            import functools
            def f(x):
                pass
            def x(f):
                @functools.wraps(f)
                def wrapper():
                    # Have no arguments here, but because of wraps, the signature
                    # should still be f's.
                    return 1
                return wrapper

            x(f)('''), 'f()'],
        [dedent('''
            # identifier:C
            import functools
            def f(x: int, y: float):
                pass

            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                return f(*args, **kwargs)

            wrapper('''), 'f(x: int, y: float)'],
        [dedent('''
            # identifier:D
            def f(x: int, y: float):
                pass

            def wrapper(*args, **kwargs):
                return f(*args, **kwargs)

            wrapper('''), 'wrapper(x: int, y: float)'],
    ]
)
def test_wraps_signature(Script, code, signature):
    """Check the single signature reported for wrapper functions.

    Cases A-C decorate the wrapper with ``functools.wraps``; case D is a plain
    forwarding wrapper without it.  Exactly one distinct signature string,
    equal to ``signature``, must be reported at the open call.
    """
    found = set()
    for sig in Script(code).get_signatures():
        found.add(sig.to_string())
    assert found == {signature}


@pytest.mark.parametrize(
    "start, start_params, include_params",
    [
        ["@dataclass\nclass X:", [], True],
        ["@dataclass(eq=True)\nclass X:", [], True],
        [
            dedent(
                """
         class Y():
             y: int
         @dataclass
         class X(Y):"""
            ),
            [],
            True,
        ],
        [
            dedent(
                """
         @dataclass
         class Y():
             y: int
             z = 5
         @dataclass
         class X(Y):"""
            ),
            ["y"],
            True,
        ],
        [
            dedent(
                """
         @dataclass
         class Y():
             y: int
         class Z(Y): # Not included
             z = 5
         @dataclass
         class X(Z):"""
            ),
            ["y"],
            True,
        ],
        # init=False
        [
            dedent(
                """
            @dataclass(init=False)
            class X:"""
            ),
            [],
            False,
        ],
        [
            dedent(
                """
            @dataclass(eq=True, init=False)
            class X:"""
            ),
            [],
            False,
        ],
        # custom init
        [
            dedent(
                """
        @dataclass()
        class X:
            def __init__(self, toto: str):
                pass
            """
            ),
            ["toto"],
            False,
        ],
    ],
    ids=[
        "direct_transformed",
        "transformed_with_params",
        "subclass_transformed",
        "both_transformed",
        "intermediate_not_transformed",
        "init_false",
        "init_false_multiple",
        "custom_init",
    ],
)
def test_dataclass_signature(
    Script, skip_pre_python37, start, start_params, include_params, environment
):
    """Check the constructor signature reported for ``@dataclass`` classes.

    :param start: Source up to and including the ``class X...:`` header; the
        common class body and the ``X(`` call are appended by the test.
    :param start_params: Parameter names expected before those of the common
        body.
    :param include_params: Whether the fields of the common body (``name``,
        ``price``, ``quantity``) are expected in the signature.
    """
    pre_38 = environment.version_info < (3, 8)
    # Final is not yet supported
    price_type = "float" if pre_38 else "Final[float]"
    price_type_infer = "float" if pre_38 else "object"

    class_body = dedent(
        f"""
            name: str
            foo = 3
            blob: ClassVar[str]
            price: {price_type}
            quantity: int = 0.0

        X("""
    )

    header = (
        "from dataclasses import dataclass\n"
        + "from typing import ClassVar, Final\n"
    )
    code = header + start + class_body

    sig, = Script(code).get_signatures()

    expected_params = [*start_params]
    if include_params:
        expected_params = [*start_params, "name", "price", "quantity"]
    assert [p.name for p in sig.params] == expected_params

    if not include_params:
        return

    # The last two parameters are ``price`` and ``quantity`` (in that order).
    quantity, = sig.params[-1].infer()
    assert quantity.name == 'int'
    price, = sig.params[-2].infer()
    assert price.name == price_type_infer


# Shared parametrization table for the ``dataclass_transform`` signature tests.
#
# Each entry is ``[start, start_params, include_params]``:
#   * ``start``: source up to and including the ``class X...:`` header (the
#     tests prepend the imports and append a common class body plus ``X(``);
#   * ``start_params``: parameter names expected before those of the common
#     body;
#   * ``include_params``: whether the common body's fields are expected in
#     the signature.
#
# The entries are paired positionally with the ``ids`` list defined after this
# table, so keep both in the same order when adding or removing cases.
#
# NOTE(review): the ``def __init_subclass__(...)`` snippets in section 4 have
# no trailing colon or body; presumably this relies on error-tolerant parsing
# -- confirm that it is intentional.
dataclass_transform_cases = [
    # Attributes on the decorated class and its base classes
    # are not considered to be fields.
    # 1/ Declare dataclass transformer
    # Base Class
    ['@dataclass_transform\nclass X:', [], False],
    # Base Class with params
    ['@dataclass_transform(eq_default=True)\nclass X:', [], False],
    # Subclass
    [dedent('''
        class Y():
            y: int
        @dataclass_transform
        class X(Y):'''), [], False],
    # 2/ Declare dataclass transformed
    # Class based
    [dedent('''
        @dataclass_transform
        class Y():
            y: int
            z = 5
        class X(Y):'''), [], True],
    # Class based with params
    [dedent('''
        @dataclass_transform(eq_default=True)
        class Y():
            y: int
            z = 5
        class X(Y):'''), [], True],
    # Decorator based
    [dedent('''
        @dataclass_transform
        def create_model():
            pass
        @create_model
        class X:'''), [], True],
    [dedent('''
        @dataclass_transform
        def create_model():
            pass
        class Y:
            y: int
        @create_model
        class X(Y):'''), [], True],
    [dedent('''
        @dataclass_transform
        def create_model():
            pass
        @create_model
        class Y:
            y: int
        @create_model
        class X(Y):'''), ["y"], True],
    [dedent('''
        @dataclass_transform
        def create_model():
            pass
        @create_model
        class Y:
            y: int
        class Z(Y):
            z: int
        @create_model
        class X(Z):'''), ["y"], True],
    # Metaclass based
    [dedent('''
        @dataclass_transform
        class ModelMeta():
            y: int
            z = 5
        class ModelBase(metaclass=ModelMeta):
            t: int
            p = 5
        class X(ModelBase):'''), [], True],
    # 3/ Init custom init
    [dedent('''
    @dataclass_transform()
    class Y():
        y: int
        z = 5
    class X(Y):
        def __init__(self, toto: str):
            pass
        '''), ["toto"], False],
    # 4/ init=false
    # Class based
    # WARNING: Unsupported
    # [dedent('''
    #     @dataclass_transform
    #     class Y():
    #         y: int
    #         z = 5
    #         def __init_subclass__(
    #             cls,
    #             *,
    #             init: bool = False,
    #         )
    #     class X(Y):'''), [], False],
    [dedent('''
        @dataclass_transform
        class Y():
            y: int
            z = 5
            def __init_subclass__(
                cls,
                *,
                init: bool = False,
            )
        class X(Y, init=True):'''), [], True],
    [dedent('''
        @dataclass_transform
        class Y():
            y: int
            z = 5
            def __init_subclass__(
                cls,
                *,
                init: bool = False,
            )
        class X(Y, init=False):'''), [], False],
    [dedent('''
        @dataclass_transform
        class Y():
            y: int
            z = 5
        class X(Y, init=False):'''), [], False],
    # Decorator based
    [dedent('''
        @dataclass_transform
        def create_model(init=False):
            pass
        @create_model()
        class X:'''), [], False],
    [dedent('''
        @dataclass_transform
        def create_model(init=False):
            pass
        @create_model(init=True)
        class X:'''), [], True],
    [dedent('''
        @dataclass_transform
        def create_model(init=False):
            pass
        @create_model(init=False)
        class X:'''), [], False],
    [dedent('''
        @dataclass_transform
        def create_model():
            pass
        @create_model(init=False)
        class X:'''), [], False],
    # Metaclass based
    [dedent('''
        @dataclass_transform
        class ModelMeta():
            y: int
            z = 5
            def __new__(
                cls,
                name,
                bases,
                namespace,
                *,
                init: bool = False,
            ):
                ...
        class ModelBase(metaclass=ModelMeta):
            t: int
            p = 5
        class X(ModelBase):'''), [], False],
    [dedent('''
        @dataclass_transform
        class ModelMeta():
            y: int
            z = 5
            def __new__(
                cls,
                name,
                bases,
                namespace,
                *,
                init: bool = False,
            ):
                ...
        class ModelBase(metaclass=ModelMeta):
            t: int
            p = 5
        class X(ModelBase, init=True):'''), [], True],
    [dedent('''
        @dataclass_transform
        class ModelMeta():
            y: int
            z = 5
            def __new__(
                cls,
                name,
                bases,
                namespace,
                *,
                init: bool = False,
            ):
                ...
        class ModelBase(metaclass=ModelMeta):
            t: int
            p = 5
        class X(ModelBase, init=False):'''), [], False],
    [dedent('''
        @dataclass_transform
        class ModelMeta():
            y: int
            z = 5
        class ModelBase(metaclass=ModelMeta):
            t: int
            p = 5
        class X(ModelBase, init=False):'''), [], False],
    # 5/ Other parameters
    # Class based
    [dedent('''
        @dataclass_transform
        class Y():
            y: int
            z = 5
        class X(Y, eq=True):'''), [], True],
    # Decorator based
    [dedent('''
        @dataclass_transform
        def create_model():
            pass
        @create_model(eq=True)
        class X:'''), [], True],
    # Metaclass based
    [dedent('''
        @dataclass_transform
        class ModelMeta():
            y: int
            z = 5
        class ModelBase(metaclass=ModelMeta):
            t: int
            p = 5
        class X(ModelBase, eq=True):'''), [], True],
]

# Test ids for ``dataclass_transform_cases``: one id per entry, in the same
# order (both are passed together to ``pytest.mark.parametrize``).  The
# commented-out id belongs to the commented-out "Unsupported" case of that
# table and must stay disabled together with it.
ids = [
    "direct_transformer",
    "transformer_with_params",
    "subclass_transformer",
    "base_transformed",
    "base_transformed_with_params",
    "decorator_transformed_direct",
    "decorator_transformed_subclass",
    "decorator_transformed_both",
    "decorator_transformed_intermediate_not",
    "metaclass_transformed",
    "custom_init",
    # "base_transformed_init_false_dataclass_init_default",
    "base_transformed_init_false_dataclass_init_true",
    "base_transformed_init_false_dataclass_init_false",
    "base_transformed_init_default_dataclass_init_false",
    "decorator_transformed_init_false_dataclass_init_default",
    "decorator_transformed_init_false_dataclass_init_true",
    "decorator_transformed_init_false_dataclass_init_false",
    "decorator_transformed_init_default_dataclass_init_false",
    "metaclass_transformed_init_false_dataclass_init_default",
    "metaclass_transformed_init_false_dataclass_init_true",
    "metaclass_transformed_init_false_dataclass_init_false",
    "metaclass_transformed_init_default_dataclass_init_false",
    "base_transformed_other_parameters",
    "decorator_transformed_other_parameters",
    "metaclass_transformed_other_parameters",
]


@pytest.mark.parametrize(
    'start, start_params, include_params', dataclass_transform_cases, ids=ids
)
def test_extensions_dataclass_transform_signature(
    Script, skip_pre_python37, start, start_params, include_params, environment
):
    """Check signatures for ``typing_extensions.dataclass_transform`` classes.

    The test is skipped when ``typing_extensions`` cannot be inferred in the
    target environment.

    :param start: Source up to and including the ``class X...:`` header; the
        common class body and the ``X(`` call are appended by the test.
    :param start_params: Parameter names expected before those of the common
        body.
    :param include_params: Whether the fields of the common body (``name``,
        ``price``, ``quantity``) are expected in the signature.
    """
    has_typing_ext = bool(Script('import typing_extensions').infer())
    if not has_typing_ext:
        # pytest.skip() raises by itself; a surrounding ``raise`` would be
        # unreachable.
        pytest.skip("typing_extensions needed in target environment to run this test")

    if environment.version_info < (3, 8):
        # Final is not yet supported
        price_type = "float"
        price_type_infer = "float"
    else:
        price_type = "Final[float]"
        price_type_infer = "object"

    code = dedent(
        f"""
            name: str
            foo = 3
            blob: ClassVar[str]
            price: {price_type}
            quantity: int = 0.0

        X("""
    )

    code = (
        "from typing_extensions import dataclass_transform\n"
        + "from typing import ClassVar, Final\n"
        + start
        + code
    )

    # Exactly one signature is expected at the open ``X(`` call.
    (sig,) = Script(code).get_signatures()
    expected_params = (
        [*start_params, "name", "price", "quantity"]
        if include_params
        else [*start_params]
    )
    assert [p.name for p in sig.params] == expected_params

    if include_params:
        # The last two parameters are ``price`` and ``quantity`` (in that order).
        quantity, = sig.params[-1].infer()
        assert quantity.name == 'int'
        price, = sig.params[-2].infer()
        assert price.name == price_type_infer


def test_dataclass_transform_complete(Script):
    """Completing ``x.na`` on a parameter annotated with ``X`` yields ``name``.

    Exactly one completion is expected and its description must be
    ``'name: str'``.
    """
    source = '''\
        @dataclass_transform
        class Y():
            y: int
            z = 5

        class X(Y):
            name: str
            foo = 3

        def f(x: X):
            x.na'''
    completions = Script(source).complete()
    completion, = completions
    assert completion.description == 'name: str'


@pytest.mark.parametrize(
    "start, start_params, include_params", dataclass_transform_cases, ids=ids
)
def test_dataclass_transform_signature(
    Script, skip_pre_python311, start, start_params, include_params
):
    """Check signatures for ``typing.dataclass_transform`` classes.

    :param start: Source up to and including the ``class X...:`` header; the
        common class body and the ``X(`` call are appended by the test.
    :param start_params: Parameter names expected before those of the common
        body.
    :param include_params: Whether the fields of the common body (``name``,
        ``price``, ``quantity``) are expected in the signature.
    """
    class_body = dedent('''
            name: str
            foo = 3
            blob: ClassVar[str]
            price: Final[float]
            quantity: int = 0.0

        X(''')

    header = (
        "from typing import dataclass_transform\n"
        + "from typing import ClassVar, Final\n"
    )
    code = header + start + class_body

    sig, = Script(code).get_signatures()

    expected_params = [*start_params]
    if include_params:
        expected_params = [*start_params, "name", "price", "quantity"]
    assert [p.name for p in sig.params] == expected_params

    if not include_params:
        return

    # The last two parameters are ``price`` and ``quantity`` (in that order).
    quantity, = sig.params[-1].infer()
    assert quantity.name == 'int'
    price, = sig.params[-2].infer()
    assert price.name == 'object'


@pytest.mark.parametrize(
    'start, start_params', [
        ['@define\nclass X:', []],
        ['@frozen\nclass X:', []],
        ['@define(eq=True)\nclass X:', []],
        [dedent('''
         class Y():
             y: int
         @define
         class X(Y):'''), []],
        [dedent('''
         @define
         class Y():
             y: int
             z = 5
         @define
         class X(Y):'''), ['y']],
    ],
    ids=["define", "frozen", "define_customized", "define_subclass", "define_both"]
)
def test_attrs_signature(Script, skip_pre_python37, start, start_params):
    """Check the constructor signature reported for ``attrs`` classes.

    The test is skipped when ``attrs`` cannot be inferred in the target
    environment.

    :param start: Source up to and including the ``class X...:`` header; the
        common class body and the ``X(`` call are appended by the test.
    :param start_params: Parameter names expected before ``name``, ``price``
        and ``quantity``.
    """
    has_attrs = bool(Script('import attrs').infer())
    if not has_attrs:
        # pytest.skip() raises by itself; a surrounding ``raise`` would be
        # unreachable.
        pytest.skip("attrs needed in target environment to run this test")

    code = dedent('''
            name: str
            foo = 3
            price: float
            quantity: int = 0.0

        X(''')

    # attrs exposes two namespaces; this test imports from ``attrs``.
    code = 'from attrs import define, frozen\n' + start + code

    # Exactly one signature is expected at the open ``X(`` call.
    sig, = Script(code).get_signatures()
    assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity']
    # The last two parameters are ``price`` and ``quantity`` (in that order).
    quantity, = sig.params[-1].infer()
    assert quantity.name == 'int'
    price, = sig.params[-2].infer()
    assert price.name == 'float'


@pytest.mark.parametrize(
    'stmt, expected', [
        ('args = 1', 'wrapped(*args, b, c)'),
        ('args = (1,)', 'wrapped(*args, c)'),
        ('kwargs = 1', 'wrapped(b, /, **kwargs)'),
        ('kwargs = dict(b=3)', 'wrapped(b, /, **kwargs)'),
    ]
)
def test_param_resolving_to_static(Script, stmt, expected):
    """Check the signature when ``args``/``kwargs`` are rebound in the wrapper.

    ``stmt`` is inserted as the first statement of ``wrapped`` and reassigns
    either ``args`` or ``kwargs`` before they are forwarded to ``func``.
    """
    template = '''\
        def full_redirect(func):
            def wrapped(*args, **kwargs):
                {stmt}
                return func(1, *args, **kwargs)
            return wrapped
        def simple(a, b, *, c): ...
        full_redirect(simple)('''
    # Substitute first, then strip the common indentation.
    source = dedent(template.format(stmt=stmt))

    signatures = Script(source).get_signatures()
    only_signature, = signatures
    assert only_signature.to_string() == expected


@pytest.mark.parametrize(
    'code', [
        'from file import with_overload; with_overload(',
        'from file import *\nwith_overload(',
    ]
)
def test_overload(Script, code):
    """Both signatures of ``with_overload`` are reported, in a fixed order.

    The script is given a path inside the ``typing_overload`` example
    directory so that ``from file import ...`` is resolved relative to it.
    """
    example_dir = get_example_dir('typing_overload')
    script = Script(code, path=os.path.join(example_dir, 'foo.py'))
    first, second = script.get_signatures()
    assert first.to_string() == 'with_overload(x: int, y: int) -> float'
    assert second.to_string() == 'with_overload(x: str, y: list) -> float'


def test_enum(Script):
    """An enum member of a class with a custom ``__init__`` has no signatures.

    Completing at ``Planet.MERCURY`` must give exactly one completion, and
    that completion must not report any call signature.
    """
    source = '''\
        from enum import Enum

        class Planet(Enum):
            MERCURY = (3.303e+23, 2.4397e6)
            VENUS = (4.869e+24, 6.0518e6)

            def __init__(self, mass, radius):
                self.mass = mass  # in kilograms
                self.radius = radius  # in meters

        Planet.MERCURY'''
    completions = Script(source).complete()
    completion, = completions
    assert not completion.get_signatures()
